拆解大语言模型RLHF中的PPO算法
policy_model = load_model()
for k in range(20000):
# 采样(生成答案)
prompts = sample_prompt()
data = respond(policy_model, prompts)
# 反馈(计算奖励)
rewards = reward_func(reward_model, data)
# 学习(更新参数)
for epoch in range(4):
policy_model = train(policy_model, prompts, data, rewards)
采样
什么是“收益”呢?简单来说就是从下一个 token 开始,模型能够获得的总奖励(浮点数标量)。这里说的奖励包括 Reward Model 给出的奖励。奖励是怎么给的,以及收益有什么用,这些内容我们后面会详细介绍。
▲ policy模型结构
从实现上说,评论家就是将演员模型的倒数第二层连接到一个新的全连接层上。除了这个全连接层之外,演员和评论家的参数都是共享的(如上图)。
上面提到的模型结构是较早期的版本,后续不共享参数的实现方式也有很多。
现在我们来看看 PPO 的采样过程中有哪些模型和变量。如下图,矩形表示模型,椭圆表示变量。
▲ 采样流程(转载须引用)
图中的“old_policy”矩形就是刚刚说的 policy(为啥有个“old”前缀呢?后面我会详细解释)。
计算 response 的第 1 个 token 的概率分布,然后从概率分布中采样出第 1 个 token 根据第 1 个 token,计算 response 的第 2 个 token 的概率分布,然后从概率分布中采样出第 2 个 token …… 根据前 N-1 个 token,计算 response 的第 N 个 token 的概率分布,然后从概率分布中采样出第 N 个 token
然后就得到了三个输出。假设对每个 prompt,policy 生成的 token 的个数为 N,那么这三个输出分别是:
response:M 个字符串,每个字符串包含 N 个 token
old_log_probs:演员输出的 M × N 的张量,包含了 response 中 token 的对数概率 log(p(token|context))
old_values:评论家输出的 M × N 的张量,包含了每次生成 token 时评论家预估的收益
得到这三个输出后,采样阶段就就结束了。这三个输出都是后续阶段重要的输入数据。
我们先将采样部分的伪代码更新一下:
# 采样
prompts = sample_prompt()
responses, old_log_probs, old_values = respond(policy_model, prompts)
反馈
来理解一下这个式子:
ref_log_prob[i] 越高,ref_policy 越认可 old_policy 的输出,说明 old_policy 更守规矩,因此应该获得更高的奖励;
old_log_prob[i] 越高,old_policy 获得的奖励反而更低。old_log_prob[i] 作为正则项,可以保证概率分布的多样性。
# 采样
prompts = sample_prompt()
responses, old_log_probs, old_values = respond(policy_model, prompts)
# policy_model的副本,不更新参数
ref_policy_model = policy_model.copy()
# 反馈
scores = reward_model(prompts, responses)
ref_log_probs = analyze_responses(ref_policy_model, prompts, responses)
rewards = reward_func(reward_model, scores, old_log_probs, ref_log_probs)
学习
“学习”就是学生根据反馈总结得失并自我改进的过程,或者说是强化优势动作的过程。
如果说前两步分别是在收集数据 X,以及给数据打上标签 Y。那么这一步就是在利用数据 (X, Y) 训练模型。
"强化优势动作"是 PPO 学习阶段的焦点。在深入探讨之前,我们首先要明确一个关键概念——优势。
此处,我们将优势定义为“实际获得的收益超出预期的程度”。
为了解释这个概念,请允许我举一个例子。假设一个高中生小明,他在高一时数学考试的平均分为 100 分,在此之后,大家对他的数学成绩的预期就是 100 分了。到了高二,他的数学平均分提升到了 130 分。在这个学期,小明的数学成绩显然是超出大家的预期的。
表现是可用分数量化的,故表现超出预期的程度也是可以用分数差来量化的。我们可以认为,在高二阶段,小明超出预期的程度为 30 分(130 - 100)。根据优势的定义我们可以说,在高二阶段,小明相对于预期获得了 30 分的优势。
在这个例子中,实际已经给出了 PPO 计算优势的方法:优势 = 实际收益 - 预期收益。
对于语言模型而言,生成第 i 个 token 的实际收益就是:从生成第 i 个 token 开始到生成第 N 个 token 为止,所能获得的所有奖励的总和。我们用 return 来表示实际收益,它的计算方式如下:
好的,我们已经理解了优势的含义了。现在终于可以揭开这个关键主题的面纱——在 PPO 学习阶段,究竟什么是"强化优势动作"。
所谓“强化优势动作”,即强化那些展现出显著优势的动作。
在上面的小明的例子中,这意味着在高三阶段,小明应该持续使用高二的学习方法,因为在高二阶段,他的学习策略展示出了显著的优势。
在语言模型中,根据上下文生成一个 token 就是所谓的“动作”。"强化优势动作"表示:如果在上下文(context)中生成了某个 token,并且这个动作的优势很高,那么我们应该增加生成该 token 的概率,即增加 p(token|context) 的值。
由于 policy 中的演员模型建模了 p(token|context),所以我们可以给演员模型设计一个损失函数,通过优化损失函数来实现“强化优势动作”:
当优势大于 0 时,概率越大,loss 越小;因此优化器会通过增大概率(即强化优势动作)来减小 loss 当优势小于 0 时,概率越小,loss 越小;因此优化器会通过减小概率(即弱化劣势动作)来减小 loss
这很像巴浦洛夫的狗不是吗?
▲ 巴浦洛夫的狗
优势的绝对值越大,loss 的绝对值也就越大 优势是不接收梯度回传的
实际上,式 5 只是一个雏形。PPO 真正使用的演员的损失函数是这样的:
* 写给熟悉 RL 的人:简单起见,在这里我们既不考虑损失的截断,也不考虑优势的白化。
现在的问题就是,我们应该使用 p 的哪个老版本。还记得我们在本文开头时给出的伪代码吗(后来在介绍“采样”和“反馈”阶段时又各更新了一次),我们对着代码来解释:
policy_model = load_model()
ref_policy_model = policy_model.copy()
for k in range(20000):
# 采样(已更新)
prompts = sample_prompt()
responses, old_log_probs, old_values = respond(policy_model, prompts)
# 反馈(已更新)
scores = reward_model(prompts, responses)
ref_log_probs = analyze_responses(ref_policy_model, prompts, responses)
rewards = reward_func(reward_model, scores, old_log_probs, ref_log_probs)
# 学习
for epoch in range(4):
policy_model = train(policy_model, prompts, responses, old_log_probs, old_values, rewards)
简单来说,这段代码做的事情是:迭代 2 万次。在每次迭代中,通过采样和反馈得到一份数据,然后在学习阶段使用数据微调语言模型。每份数据我们都拿来训练 4 个 epoch。
至此,我们完整地描述了 PPO 的学习阶段中“强化优势动作”的方法。就像下面的计算图展示的那样(policy 与前面的图中的 old_policy 不一样,是实时版本的模型)。
▲ 学习流程(转载须引用)
等等,似乎还没完。图中还有一个叫 critic_loss 的没提到过的东西。
当然了,负责决策的演员需要学习,难道总结得失的评论家就不需要学习了?评论家也是需要与时俱进的嘛,否则画评家难道不怕再次错过梵高那样的天才?
既然如此,就设计一个损失函数来衡量评论家预期收益和真实收益之间的差距。
PPO 用的是均方差损失(MSE):
* 写给熟悉 RL 的人:由于我们不考虑 GAE,所以 returns 的计算也做了相应的简化。
最终优化 policy 时用的 loss 是演员和评论家的 loss 的加权和:
这才算是真正完事儿了。现在我们将整个 PPO 的伪代码都更新一下:
policy_model = load_model()
ref_policy_model = policy_model.copy()
for k in range(20000):
# 采样
prompts = sample_prompt()
responses, old_log_probs, old_values = respond(policy_model, prompts)
# 反馈
scores = reward_model(prompts, responses)
ref_log_probs, _ = analyze_responses(ref_policy_model, prompts, responses)
rewards = reward_func(reward_model, scores, old_log_probs, ref_log_probs)
# 学习
for epoch in range(4):
log_probs, values = analyze_responses(policy_model, prompts, responses)
advantages = advantage_func(rewards, old_values)
actor_loss = actor_loss_func(advantages, old_log_probs, log_probs)
critic_loss = critic_loss_func(rewards, values)
loss = actor_loss + 0.1 * critic_loss
train(loss, policy_model.parameters())
总结
到这里,大语言模型 RLHF 中 PPO 算法的完整细节就算介绍完了。掌握这些细节之后,我们可以做的有趣的事情就变多了。例如:
你可以照着伪代码从头到尾自己实现一遍,以加深理解。相信我,这是非常有趣且快乐的过程
你可以以此为契机,把强化学习知识系统性地学一遍。你会发现很多强化学习的概念一下变得具象化了
你可以在你的产品或者研究方向中思考 PPO 是否可以落地
你也许会发现 PPO 算法的不合理之处,那么就深入研究下去,直到做出自己的改进
你可以跟周围不熟悉 PPO 的小伙伴吹牛,顺便嘲讽对方(大误)
总之,希望我们都因为掌握了知识变得更加充实和快乐~
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧